<!DOCTYPE html>
<HTML lang = "en">
<HEAD>
  <meta charset="UTF-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
  <title>Differentiable Programming and Neural Differential Equations</title>
  

  <script type="text/x-mathjax-config">
    MathJax.Hub.Config({
      tex2jax: {inlineMath: [['$','$'], ['\\(','\\)']]},
      TeX: { equationNumbers: { autoNumber: "AMS" } }
    });
  </script>

  <script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML">
  </script>

  
<style>
pre.hljl {
    border: 1px solid #ccc;
    margin: 5px;
    padding: 5px;
    overflow-x: auto;
    color: rgb(68,68,68); background-color: rgb(251,251,251); }
pre.hljl > span.hljl-t { }
pre.hljl > span.hljl-w { }
pre.hljl > span.hljl-e { }
pre.hljl > span.hljl-eB { }
pre.hljl > span.hljl-o { }
pre.hljl > span.hljl-k { color: rgb(148,91,176); font-weight: bold; }
pre.hljl > span.hljl-kc { color: rgb(59,151,46); font-style: italic; }
pre.hljl > span.hljl-kd { color: rgb(214,102,97); font-style: italic; }
pre.hljl > span.hljl-kn { color: rgb(148,91,176); font-weight: bold; }
pre.hljl > span.hljl-kp { color: rgb(148,91,176); font-weight: bold; }
pre.hljl > span.hljl-kr { color: rgb(148,91,176); font-weight: bold; }
pre.hljl > span.hljl-kt { color: rgb(148,91,176); font-weight: bold; }
pre.hljl > span.hljl-n { }
pre.hljl > span.hljl-na { }
pre.hljl > span.hljl-nb { }
pre.hljl > span.hljl-nbp { }
pre.hljl > span.hljl-nc { }
pre.hljl > span.hljl-ncB { }
pre.hljl > span.hljl-nd { color: rgb(214,102,97); }
pre.hljl > span.hljl-ne { }
pre.hljl > span.hljl-neB { }
pre.hljl > span.hljl-nf { color: rgb(66,102,213); }
pre.hljl > span.hljl-nfm { color: rgb(66,102,213); }
pre.hljl > span.hljl-np { }
pre.hljl > span.hljl-nl { }
pre.hljl > span.hljl-nn { }
pre.hljl > span.hljl-no { }
pre.hljl > span.hljl-nt { }
pre.hljl > span.hljl-nv { }
pre.hljl > span.hljl-nvc { }
pre.hljl > span.hljl-nvg { }
pre.hljl > span.hljl-nvi { }
pre.hljl > span.hljl-nvm { }
pre.hljl > span.hljl-l { }
pre.hljl > span.hljl-ld { color: rgb(148,91,176); font-style: italic; }
pre.hljl > span.hljl-s { color: rgb(201,61,57); }
pre.hljl > span.hljl-sa { color: rgb(201,61,57); }
pre.hljl > span.hljl-sb { color: rgb(201,61,57); }
pre.hljl > span.hljl-sc { color: rgb(201,61,57); }
pre.hljl > span.hljl-sd { color: rgb(201,61,57); }
pre.hljl > span.hljl-sdB { color: rgb(201,61,57); }
pre.hljl > span.hljl-sdC { color: rgb(201,61,57); }
pre.hljl > span.hljl-se { color: rgb(59,151,46); }
pre.hljl > span.hljl-sh { color: rgb(201,61,57); }
pre.hljl > span.hljl-si { }
pre.hljl > span.hljl-so { color: rgb(201,61,57); }
pre.hljl > span.hljl-sr { color: rgb(201,61,57); }
pre.hljl > span.hljl-ss { color: rgb(201,61,57); }
pre.hljl > span.hljl-ssB { color: rgb(201,61,57); }
pre.hljl > span.hljl-nB { color: rgb(59,151,46); }
pre.hljl > span.hljl-nbB { color: rgb(59,151,46); }
pre.hljl > span.hljl-nfB { color: rgb(59,151,46); }
pre.hljl > span.hljl-nh { color: rgb(59,151,46); }
pre.hljl > span.hljl-ni { color: rgb(59,151,46); }
pre.hljl > span.hljl-nil { color: rgb(59,151,46); }
pre.hljl > span.hljl-noB { color: rgb(59,151,46); }
pre.hljl > span.hljl-oB { color: rgb(102,102,102); font-weight: bold; }
pre.hljl > span.hljl-ow { color: rgb(102,102,102); font-weight: bold; }
pre.hljl > span.hljl-p { }
pre.hljl > span.hljl-c { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-ch { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-cm { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-cp { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-cpB { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-cs { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-csB { color: rgb(153,153,119); font-style: italic; }
pre.hljl > span.hljl-g { }
pre.hljl > span.hljl-gd { }
pre.hljl > span.hljl-ge { }
pre.hljl > span.hljl-geB { }
pre.hljl > span.hljl-gh { }
pre.hljl > span.hljl-gi { }
pre.hljl > span.hljl-go { }
pre.hljl > span.hljl-gp { }
pre.hljl > span.hljl-gs { }
pre.hljl > span.hljl-gsB { }
pre.hljl > span.hljl-gt { }
</style>



  <style type="text/css">
  @font-face {
  font-style: normal;
  font-weight: 300;
}
@font-face {
  font-style: normal;
  font-weight: 400;
}
@font-face {
  font-style: normal;
  font-weight: 600;
}
html {
  font-family: sans-serif; /* 1 */
  -ms-text-size-adjust: 100%; /* 2 */
  -webkit-text-size-adjust: 100%; /* 2 */
}
body {
  margin: 0;
}
article,
aside,
details,
figcaption,
figure,
footer,
header,
hgroup,
main,
menu,
nav,
section,
summary {
  display: block;
}
audio,
canvas,
progress,
video {
  display: inline-block; /* 1 */
  vertical-align: baseline; /* 2 */
}
audio:not([controls]) {
  display: none;
  height: 0;
}
[hidden],
template {
  display: none;
}
a:active,
a:hover {
  outline: 0;
}
abbr[title] {
  border-bottom: 1px dotted;
}
b,
strong {
  font-weight: bold;
}
dfn {
  font-style: italic;
}
h1 {
  font-size: 2em;
  margin: 0.67em 0;
}
mark {
  background: #ff0;
  color: #000;
}
small {
  font-size: 80%;
}
sub,
sup {
  font-size: 75%;
  line-height: 0;
  position: relative;
  vertical-align: baseline;
}
sup {
  top: -0.5em;
}
sub {
  bottom: -0.25em;
}
img {
  border: 0;
}
svg:not(:root) {
  overflow: hidden;
}
figure {
  margin: 1em 40px;
}
hr {
  -moz-box-sizing: content-box;
  box-sizing: content-box;
  height: 0;
}
pre {
  overflow: auto;
}
code,
kbd,
pre,
samp {
  font-family: monospace, monospace;
  font-size: 1em;
}
button,
input,
optgroup,
select,
textarea {
  color: inherit; /* 1 */
  font: inherit; /* 2 */
  margin: 0; /* 3 */
}
button {
  overflow: visible;
}
button,
select {
  text-transform: none;
}
button,
html input[type="button"], /* 1 */
input[type="reset"],
input[type="submit"] {
  -webkit-appearance: button; /* 2 */
  cursor: pointer; /* 3 */
}
button[disabled],
html input[disabled] {
  cursor: default;
}
button::-moz-focus-inner,
input::-moz-focus-inner {
  border: 0;
  padding: 0;
}
input {
  line-height: normal;
}
input[type="checkbox"],
input[type="radio"] {
  box-sizing: border-box; /* 1 */
  padding: 0; /* 2 */
}
input[type="number"]::-webkit-inner-spin-button,
input[type="number"]::-webkit-outer-spin-button {
  height: auto;
}
input[type="search"] {
  -webkit-appearance: textfield; /* 1 */
  -moz-box-sizing: content-box;
  -webkit-box-sizing: content-box; /* 2 */
  box-sizing: content-box;
}
input[type="search"]::-webkit-search-cancel-button,
input[type="search"]::-webkit-search-decoration {
  -webkit-appearance: none;
}
fieldset {
  border: 1px solid #c0c0c0;
  margin: 0 2px;
  padding: 0.35em 0.625em 0.75em;
}
legend {
  border: 0; /* 1 */
  padding: 0; /* 2 */
}
textarea {
  overflow: auto;
}
optgroup {
  font-weight: bold;
}
table {
  font-family: monospace, monospace;
  font-size : 0.8em;
  border-collapse: collapse;
  border-spacing: 0;
}
td,
th {
  padding: 0;
}
thead th {
    border-bottom: 1px solid black;
    background-color: white;
}
tr:nth-child(odd){
  background-color: rgb(248,248,248);
}


/*
* Skeleton V2.0.4
* Copyright 2014, Dave Gamache
* www.getskeleton.com
* Free to use under the MIT license.
* http://www.opensource.org/licenses/mit-license.php
* 12/29/2014
*/
.container {
  position: relative;
  width: 100%;
  max-width: 960px;
  margin: 0 auto;
  padding: 0 20px;
  box-sizing: border-box; }
.column,
.columns {
  width: 100%;
  float: left;
  box-sizing: border-box; }
@media (min-width: 400px) {
  .container {
    width: 85%;
    padding: 0; }
}
@media (min-width: 550px) {
  .container {
    width: 80%; }
  .column,
  .columns {
    margin-left: 4%; }
  .column:first-child,
  .columns:first-child {
    margin-left: 0; }

  .one.column,
  .one.columns                    { width: 4.66666666667%; }
  .two.columns                    { width: 13.3333333333%; }
  .three.columns                  { width: 22%;            }
  .four.columns                   { width: 30.6666666667%; }
  .five.columns                   { width: 39.3333333333%; }
  .six.columns                    { width: 48%;            }
  .seven.columns                  { width: 56.6666666667%; }
  .eight.columns                  { width: 65.3333333333%; }
  .nine.columns                   { width: 74.0%;          }
  .ten.columns                    { width: 82.6666666667%; }
  .eleven.columns                 { width: 91.3333333333%; }
  .twelve.columns                 { width: 100%; margin-left: 0; }

  .one-third.column               { width: 30.6666666667%; }
  .two-thirds.column              { width: 65.3333333333%; }

  .one-half.column                { width: 48%; }

  /* Offsets */
  .offset-by-one.column,
  .offset-by-one.columns          { margin-left: 8.66666666667%; }
  .offset-by-two.column,
  .offset-by-two.columns          { margin-left: 17.3333333333%; }
  .offset-by-three.column,
  .offset-by-three.columns        { margin-left: 26%;            }
  .offset-by-four.column,
  .offset-by-four.columns         { margin-left: 34.6666666667%; }
  .offset-by-five.column,
  .offset-by-five.columns         { margin-left: 43.3333333333%; }
  .offset-by-six.column,
  .offset-by-six.columns          { margin-left: 52%;            }
  .offset-by-seven.column,
  .offset-by-seven.columns        { margin-left: 60.6666666667%; }
  .offset-by-eight.column,
  .offset-by-eight.columns        { margin-left: 69.3333333333%; }
  .offset-by-nine.column,
  .offset-by-nine.columns         { margin-left: 78.0%;          }
  .offset-by-ten.column,
  .offset-by-ten.columns          { margin-left: 86.6666666667%; }
  .offset-by-eleven.column,
  .offset-by-eleven.columns       { margin-left: 95.3333333333%; }

  .offset-by-one-third.column,
  .offset-by-one-third.columns    { margin-left: 34.6666666667%; }
  .offset-by-two-thirds.column,
  .offset-by-two-thirds.columns   { margin-left: 69.3333333333%; }

  .offset-by-one-half.column,
  .offset-by-one-half.columns     { margin-left: 52%; }

}
html {
  font-size: 62.5%; }
body {
  font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */
  line-height: 1.6;
  font-weight: 400;
  font-family: "Raleway", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif;
  color: #222; }
h1, h2, h3, h4, h5, h6 {
  margin-top: 0;
  margin-bottom: 2rem;
  font-weight: 300; }
h1 { font-size: 3.6rem; line-height: 1.2;  letter-spacing: -.1rem;}
h2 { font-size: 3.4rem; line-height: 1.25; letter-spacing: -.1rem; }
h3 { font-size: 3.2rem; line-height: 1.3;  letter-spacing: -.1rem; }
h4 { font-size: 2.8rem; line-height: 1.35; letter-spacing: -.08rem; }
h5 { font-size: 2.4rem; line-height: 1.5;  letter-spacing: -.05rem; }
h6 { font-size: 1.5rem; line-height: 1.6;  letter-spacing: 0; }

p {
  margin-top: 0; }
a {
  color: #1EAEDB; }
a:hover {
  color: #0FA0CE; }
.button,
button,
input[type="submit"],
input[type="reset"],
input[type="button"] {
  display: inline-block;
  height: 38px;
  padding: 0 30px;
  color: #555;
  text-align: center;
  font-size: 11px;
  font-weight: 600;
  line-height: 38px;
  letter-spacing: .1rem;
  text-transform: uppercase;
  text-decoration: none;
  white-space: nowrap;
  background-color: transparent;
  border-radius: 4px;
  border: 1px solid #bbb;
  cursor: pointer;
  box-sizing: border-box; }
.button:hover,
button:hover,
input[type="submit"]:hover,
input[type="reset"]:hover,
input[type="button"]:hover,
.button:focus,
button:focus,
input[type="submit"]:focus,
input[type="reset"]:focus,
input[type="button"]:focus {
  color: #333;
  border-color: #888;
  outline: 0; }
.button.button-primary,
button.button-primary,
input[type="submit"].button-primary,
input[type="reset"].button-primary,
input[type="button"].button-primary {
  color: #FFF;
  background-color: #33C3F0;
  border-color: #33C3F0; }
.button.button-primary:hover,
button.button-primary:hover,
input[type="submit"].button-primary:hover,
input[type="reset"].button-primary:hover,
input[type="button"].button-primary:hover,
.button.button-primary:focus,
button.button-primary:focus,
input[type="submit"].button-primary:focus,
input[type="reset"].button-primary:focus,
input[type="button"].button-primary:focus {
  color: #FFF;
  background-color: #1EAEDB;
  border-color: #1EAEDB; }
input[type="email"],
input[type="number"],
input[type="search"],
input[type="text"],
input[type="tel"],
input[type="url"],
input[type="password"],
textarea,
select {
  height: 38px;
  padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */
  background-color: #fff;
  border: 1px solid #D1D1D1;
  border-radius: 4px;
  box-shadow: none;
  box-sizing: border-box; }
/* Removes awkward default styles on some inputs for iOS */
input[type="email"],
input[type="number"],
input[type="search"],
input[type="text"],
input[type="tel"],
input[type="url"],
input[type="password"],
textarea {
  -webkit-appearance: none;
     -moz-appearance: none;
          appearance: none; }
textarea {
  min-height: 65px;
  padding-top: 6px;
  padding-bottom: 6px; }
input[type="email"]:focus,
input[type="number"]:focus,
input[type="search"]:focus,
input[type="text"]:focus,
input[type="tel"]:focus,
input[type="url"]:focus,
input[type="password"]:focus,
textarea:focus,
select:focus {
  border: 1px solid #33C3F0;
  outline: 0; }
label,
legend {
  display: block;
  margin-bottom: .5rem;
  font-weight: 600; }
fieldset {
  padding: 0;
  border-width: 0; }
input[type="checkbox"],
input[type="radio"] {
  display: inline; }
label > .label-body {
  display: inline-block;
  margin-left: .5rem;
  font-weight: normal; }
ul {
  list-style: circle; }
ol {
  list-style: decimal; }
ul ul,
ul ol,
ol ol,
ol ul {
  margin: 1.5rem 0 1.5rem 3rem;
  font-size: 90%; }
li > p {margin : 0;}
th,
td {
  padding: 12px 15px;
  text-align: left;
  border-bottom: 1px solid #E1E1E1; }
th:first-child,
td:first-child {
  padding-left: 0; }
th:last-child,
td:last-child {
  padding-right: 0; }
button,
.button {
  margin-bottom: 1rem; }
input,
textarea,
select,
fieldset {
  margin-bottom: 1.5rem; }
pre,
blockquote,
dl,
figure,
table,
p,
ul,
ol,
form {
  margin-bottom: 1.0rem; }
.u-full-width {
  width: 100%;
  box-sizing: border-box; }
.u-max-full-width {
  max-width: 100%;
  box-sizing: border-box; }
.u-pull-right {
  float: right; }
.u-pull-left {
  float: left; }
hr {
  margin-top: 3rem;
  margin-bottom: 3.5rem;
  border-width: 0;
  border-top: 1px solid #E1E1E1; }
.container:after,
.row:after,
.u-cf {
  content: "";
  display: table;
  clear: both; }

pre {
  display: block;
  padding: 9.5px;
  margin: 0 0 10px;
  font-size: 13px;
  line-height: 1.42857143;
  word-break: break-all;
  word-wrap: break-word;
  border: 1px solid #ccc;
  border-radius: 4px;
}

pre.hljl {
  margin: 0 0 10px;
  display: block;
  background: #f5f5f5;
  border-radius: 4px;
  padding : 5px;
}

pre.output {
  background: #ffffff;
}

pre.code {
  background: #ffffff;
}

pre.julia-error {
  color : red
}

code,
kbd,
pre,
samp {
  font-family: Menlo, Monaco, Consolas, "Courier New", monospace;
  font-size: 0.9em;
}


@media (min-width: 400px) {}
@media (min-width: 550px) {}
@media (min-width: 750px) {}
@media (min-width: 1000px) {}
@media (min-width: 1200px) {}

h1.title {margin-top : 20px}
img {max-width : 100%}
div.title {text-align: center;}

  </style>
</HEAD>

<BODY>
  <div class ="container">
    <div class = "row">
      <div class = "col-md-12 twelve columns">
        <div class="title">
          <h1 class="title">Differentiable Programming and Neural Differential Equations</h1>
          <h5>Chris Rackauckas</h5>
          <h5>November 20th, 2020</h5>
        </div>

        <p>Our last discussion focused on how, at a high mathematical level, one could in theory build programs which compute gradients quickly by looking at the computational graph and performing reverse-mode automatic differentiation. Within the context of parameter identification, we saw many advantages to this approach because its cost does not scale multiplicatively in the number of parameters, and thus it is an efficient way to calculate Jacobians of objects where there are fewer rows than columns &#40;think of the gradient as 1 row&#41;.</p>
<p>More precisely, this is really about sparsity patterns: reverse-mode is more efficient if &quot;enough&quot; fewer row seeds are required than column partials &#40;with mixed-mode approaches sometimes being much better&#41;. However, to make reverse-mode AD realistically usable inside of a programming language instead of a compute graph, we need to do three things:</p>
<ol>
<li><p>We need to have a way of implementing reverse-mode AD in a language.</p>
</li>
<li><p>We need a systematic way to derive &quot;adjoint&quot; relationships &#40;pullbacks&#41;.</p>
</li>
<li><p>We need to see if there are better ways to fit parameters to data, rather than performing reverse-mode AD through entire programs&#33;</p>
</li>
</ol>
<h2>Implementation of Reverse-Mode AD</h2>
<p>Forward-mode AD was implementable through operator overloading and dual number arithmetic. However, reverse-mode AD requires reversing a program through its computational structure, which is a much more difficult operation. This raises the question: how does one actually build a reverse-mode AD implementation?</p>
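<p>As a reminder of what the operator-overloading approach looks like, here is a minimal sketch of dual number arithmetic. The <code>Dual</code> type and the three overloaded primitives are illustrative assumptions, not the API of any particular package.</p>

<pre class='hljl'>
# Hypothetical minimal dual number: a value plus one derivative component.
struct Dual{T&lt;:Real}
    val::T   # primal value
    der::T   # derivative carried alongside the value
end

# Overload a few primitives with their derivative rules.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

f(x) = x * x + sin(x)

# Seed the derivative component with 1 to differentiate with respect to x.
d = f(Dual(2.0, 1.0))
d.val   # f(2.0)
d.der   # f'(2.0) = 2 * 2.0 + cos(2.0)
</pre>

<p>Each overloaded operation only needs local information, which is why forward-mode fits operator overloading so naturally. Reverse-mode has no such luxury: the derivative information flows opposite to the order in which the program executes.</p>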
<h3>Static Graph AD</h3>
<p>The most obvious solution is to use a static compute graph, since we defined our differentiation structure on a compute graph. TensorFlow is a modern example of this approach, where a user must define variables and operations in a graph language &#40;that&#39;s embedded into Python, R, Julia, etc.&#41;, and then execution on the graph is easy to differentiate. This has the advantage of being a simplified and controlled form, which means that not only are differentiation transformations possible, but so are things like automatic parallelization. However, many see directly writing a &#40;static&#41; computation graph as a barrier for practical use since it requires completely rewriting all existing programs to this structure.</p>
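<p>To illustrate the idea, the following sketch defines a toy graph language and a derivative as a graph-to-graph transformation. All node types and function names here are hypothetical, and the derivative uses a simple symbolic rule per node rather than TensorFlow&#39;s API or its reverse-mode algorithm. The point is only that the program must be written in the graph language before anything can be done with it.</p>

<pre class='hljl'>
# Hypothetical node types for a tiny static compute graph.
abstract type Node end
struct Variable &lt;: Node
    name::Symbol
end
struct Constant &lt;: Node
    value::Float64
end
struct Plus &lt;: Node
    a::Node
    b::Node
end
struct Times &lt;: Node
    a::Node
    b::Node
end

# Executing the graph: evaluate each node given variable bindings.
evaluate(n::Variable, env) = env[n.name]
evaluate(n::Constant, env) = n.value
evaluate(n::Plus, env) = evaluate(n.a, env) + evaluate(n.b, env)
evaluate(n::Times, env) = evaluate(n.a, env) * evaluate(n.b, env)

# Differentiating the graph: each node type maps to a new subgraph.
derivative(n::Variable, wrt) = Constant(n.name == wrt ? 1.0 : 0.0)
derivative(n::Constant, wrt) = Constant(0.0)
derivative(n::Plus, wrt) = Plus(derivative(n.a, wrt), derivative(n.b, wrt))
derivative(n::Times, wrt) =
    Plus(Times(derivative(n.a, wrt), n.b), Times(n.a, derivative(n.b, wrt)))

# The "program" x*y + x has to be built explicitly as a graph.
x, y = Variable(:x), Variable(:y)
graph = Plus(Times(x, y), x)
env = Dict(:x =&gt; 3.0, :y =&gt; 4.0)

evaluate(graph, env)                    # x*y + x
evaluate(derivative(graph, :x), env)    # y + 1
</pre>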
<h3>Tracing-Based AD and Wengert Lists</h3>
<p>Recall that an alternative formulation of reverse-mode AD for composed functions</p>
<p class="math">\[
f = f^L \circ f^{L-1} \circ \ldots \circ f^1
\]</p>
<p>is through pullbacks on the Jacobians:</p>
<p class="math">\[
v^T J = (\ldots ((v^T J_L) J_{L-1}) \ldots ) J_1
\]</p>
<p>Therefore, if one can transform the program structure into a list of composed functions, then reverse-mode AD is the successive application of pullbacks going in the reverse direction:</p>
<p class="math">\[
\mathcal{B}_{f}^{x}(A)=\mathcal{B}_{f^{1}}^{x}\left(\ldots\left(\mathcal{B}_{f^{L-1}}^{f^{L-2}(f^{L-3}(\ldots f^{1}(x)\ldots))}\left(\mathcal{B}_{f^{L}}^{f^{L-1}(f^{L-2}(\ldots f^{1}(x)\ldots))}(A)\right)\right)\ldots\right)
\]</p>
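<p>The nesting in this formula can be sketched for <span class="math">$L = 3$</span> as follows. The three layers and the helper <code>value_and_pullback</code> are hypothetical names chosen for illustration; each layer returns its output together with a closure implementing its pullback at the input it was called with.</p>

<pre class='hljl'>
# Each hypothetical layer returns (output, pullback), where the pullback
# closure captures the forward value it needs.
layer1(x) = (x .^ 2, Δ -&gt; 2 .* x .* Δ)          # f¹(x) = x.^2
layer2(x) = (sin.(x), Δ -&gt; cos.(x) .* Δ)        # f²(x) = sin.(x)
layer3(x) = (sum(x), Δ -&gt; fill(Δ, length(x)))   # f³(x) = sum(x)

function value_and_pullback(layers, x)
    backs = []                        # pullbacks stored in forward order
    for layer in layers
        x, back = layer(x)            # the forward value feeds the next layer
        push!(backs, back)
    end
    # Apply the pullbacks in reverse order: B_{f¹}(B_{f²}(B_{f³}(A))).
    pullback(A) = foldl((v, b) -&gt; b(v), reverse(backs); init = A)
    return x, pullback
end

x0 = [0.5, 1.0, 1.5]
y, back = value_and_pullback((layer1, layer2, layer3), x0)
grad = back(1.0)   # gradient of sum(sin.(x.^2)), i.e. 2 .* x0 .* cos.(x0 .^ 2)
</pre>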
<p>Recall that the pullback <span class="math">$\mathcal{B}_f^x(\overline{y})$</span> requires knowing:</p>
<ol>
<li><p>The operation being performed</p>
</li>
<li><p>The value <span class="math">$x$</span> of the forward pass</p>
</li>
</ol>
<p>The idea is to then build a <em>Wengert list</em> from exactly the forward pass at a specific <span class="math">$x$</span>, also known as a <em>trace</em>, thus giving rise to <em>tracing-based reverse-mode AD</em>. This is the basis of many reverse-mode implementations, such as Julia&#39;s Tracker.jl &#40;Flux.jl&#39;s old AD&#41;, ReverseDiff.jl, PyTorch, TensorFlow Eager, Autograd, and Autograd.jl. It is widely adopted because it is simple to implement.</p>
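<p>A minimal sketch of the tracing idea for scalars might look like the following. The names <code>TapeVar</code>, <code>TAPE</code>, <code>track</code>, and <code>backward!</code> are hypothetical, and the global tape is a simplification for illustration rather than the design of any of the packages listed above.</p>

<pre class='hljl'>
# Hypothetical tracked scalar: the forward value plus an accumulated gradient.
mutable struct TapeVar
    value::Float64
    grad::Float64
end

# The Wengert list: one backward closure per recorded operation.
const TAPE = Function[]

track(v) = TapeVar(v, 0.0)

function Base.:*(a::TapeVar, b::TapeVar)
    out = TapeVar(a.value * b.value, 0.0)
    push!(TAPE, () -&gt; (a.grad += b.value * out.grad; b.grad += a.value * out.grad))
    return out
end

function Base.:+(a::TapeVar, b::TapeVar)
    out = TapeVar(a.value + b.value, 0.0)
    push!(TAPE, () -&gt; (a.grad += out.grad; b.grad += out.grad))
    return out
end

function Base.sin(a::TapeVar)
    out = TapeVar(sin(a.value), 0.0)
    push!(TAPE, () -&gt; (a.grad += cos(a.value) * out.grad))
    return out
end

# Reverse pass: seed the output with 1 and replay the tape backwards.
function backward!(out::TapeVar)
    out.grad = 1.0
    foreach(step -&gt; step(), Iterators.reverse(TAPE))
    empty!(TAPE)
    return nothing
end

x, y = track(2.0), track(3.0)
z = x * y + sin(x)   # the forward pass records three operations
backward!(z)
x.grad   # y + cos(x)
y.grad   # x
</pre>

<p>Note that the tape only contains the operations that were actually executed for this particular input, so branches and loops in the original program are already unrolled by the time the reverse pass runs.</p>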
<h4>Inspecting Tracker.jl</h4>
<p>Tracker.jl is a very simple implementation to inspect. The definitions of its number and array types are as follows:</p>


<pre class='hljl'>
<span class='hljl-k'>struct</span><span class='hljl-t'> </span><span class='hljl-nf'>Call</span><span class='hljl-p'>{</span><span class='hljl-n'>F</span><span class='hljl-p'>,</span><span class='hljl-n'>As</span><span class='hljl-oB'>&lt;:</span><span class='hljl-n'>Tuple</span><span class='hljl-p'>}</span><span class='hljl-t'>
  </span><span class='hljl-n'>func</span><span class='hljl-oB'>::</span><span class='hljl-n'>F</span><span class='hljl-t'>
  </span><span class='hljl-n'>args</span><span class='hljl-oB'>::</span><span class='hljl-n'>As</span><span class='hljl-t'>
</span><span class='hljl-k'>end</span><span class='hljl-t'>

</span><span class='hljl-k'>mutable struct</span><span class='hljl-t'> </span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>}</span><span class='hljl-t'>
  </span><span class='hljl-n'>ref</span><span class='hljl-oB'>::</span><span class='hljl-n'>UInt32</span><span class='hljl-t'>
  </span><span class='hljl-n'>f</span><span class='hljl-oB'>::</span><span class='hljl-n'>Call</span><span class='hljl-t'>
  </span><span class='hljl-n'>isleaf</span><span class='hljl-oB'>::</span><span class='hljl-n'>Bool</span><span class='hljl-t'>
  </span><span class='hljl-n'>grad</span><span class='hljl-oB'>::</span><span class='hljl-n'>T</span><span class='hljl-t'>
  </span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>}(</span><span class='hljl-n'>f</span><span class='hljl-oB'>::</span><span class='hljl-n'>Call</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-n'>where</span><span class='hljl-t'> </span><span class='hljl-n'>T</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>new</span><span class='hljl-p'>(</span><span class='hljl-ni'>0</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>f</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-kc'>false</span><span class='hljl-p'>)</span><span class='hljl-t'>
  </span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>}(</span><span class='hljl-n'>f</span><span class='hljl-oB'>::</span><span class='hljl-n'>Call</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-oB'>::</span><span class='hljl-n'>T</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-n'>where</span><span class='hljl-t'> </span><span class='hljl-n'>T</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>new</span><span class='hljl-p'>(</span><span class='hljl-ni'>0</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>f</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-kc'>false</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-p'>)</span><span class='hljl-t'>
  </span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>}(</span><span class='hljl-n'>f</span><span class='hljl-oB'>::</span><span class='hljl-nf'>Call</span><span class='hljl-p'>{</span><span class='hljl-n'>Nothing</span><span class='hljl-p'>},</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-oB'>::</span><span class='hljl-n'>T</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-n'>where</span><span class='hljl-t'> </span><span class='hljl-n'>T</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>new</span><span class='hljl-p'>(</span><span class='hljl-ni'>0</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>f</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-kc'>true</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-p'>)</span><span class='hljl-t'>
</span><span class='hljl-k'>end</span><span class='hljl-t'>

</span><span class='hljl-k'>mutable struct</span><span class='hljl-t'> </span><span class='hljl-nf'>TrackedReal</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-oB'>&lt;:</span><span class='hljl-n'>Real</span><span class='hljl-p'>}</span><span class='hljl-t'> </span><span class='hljl-oB'>&lt;:</span><span class='hljl-t'> </span><span class='hljl-n'>Real</span><span class='hljl-t'>
  </span><span class='hljl-n'>data</span><span class='hljl-oB'>::</span><span class='hljl-n'>T</span><span class='hljl-t'>
  </span><span class='hljl-n'>tracker</span><span class='hljl-oB'>::</span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>}</span><span class='hljl-t'>
</span><span class='hljl-k'>end</span><span class='hljl-t'>

</span><span class='hljl-k'>struct</span><span class='hljl-t'> </span><span class='hljl-nf'>TrackedArray</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>,</span><span class='hljl-n'>A</span><span class='hljl-oB'>&lt;:</span><span class='hljl-nf'>AbstractArray</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>}}</span><span class='hljl-t'> </span><span class='hljl-oB'>&lt;:</span><span class='hljl-t'> </span><span class='hljl-nf'>AbstractArray</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>}</span><span class='hljl-t'>
  </span><span class='hljl-n'>tracker</span><span class='hljl-oB'>::</span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>A</span><span class='hljl-p'>}</span><span class='hljl-t'>
  </span><span class='hljl-n'>data</span><span class='hljl-oB'>::</span><span class='hljl-n'>A</span><span class='hljl-t'>
  </span><span class='hljl-n'>grad</span><span class='hljl-oB'>::</span><span class='hljl-n'>A</span><span class='hljl-t'>
  </span><span class='hljl-nf'>TrackedArray</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>,</span><span class='hljl-n'>A</span><span class='hljl-p'>}(</span><span class='hljl-n'>t</span><span class='hljl-oB'>::</span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>A</span><span class='hljl-p'>},</span><span class='hljl-t'> </span><span class='hljl-n'>data</span><span class='hljl-oB'>::</span><span class='hljl-n'>A</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-n'>where</span><span class='hljl-t'> </span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>,</span><span class='hljl-n'>A</span><span class='hljl-p'>}</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>new</span><span class='hljl-p'>(</span><span class='hljl-n'>t</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>data</span><span class='hljl-p'>)</span><span class='hljl-t'>
  </span><span class='hljl-nf'>TrackedArray</span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>,</span><span class='hljl-n'>A</span><span class='hljl-p'>}(</span><span class='hljl-n'>t</span><span class='hljl-oB'>::</span><span class='hljl-nf'>Tracked</span><span class='hljl-p'>{</span><span class='hljl-n'>A</span><span class='hljl-p'>},</span><span class='hljl-t'> </span><span class='hljl-n'>data</span><span class='hljl-oB'>::</span><span class='hljl-n'>A</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-oB'>::</span><span class='hljl-n'>A</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-n'>where</span><span class='hljl-t'> </span><span class='hljl-p'>{</span><span class='hljl-n'>T</span><span class='hljl-p'>,</span><span class='hljl-n'>N</span><span class='hljl-p'>,</span><span class='hljl-n'>A</span><span class='hljl-p'>}</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>new</span><span class='hljl-p'>(</span><span class='hljl-n'>t</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>data</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>grad</span><span class='hljl-p'>)</span><span class='hljl-t'>
</span><span class='hljl-k'>end</span>
</pre>


<p>As expected, it replaces every single number and array with a value that does not just perform the operation, but also builds up a list of operations along with the values at every stage. Pullback rules are then implemented for primitives via the <code>@grad</code> macro. For example, the pullback for the dot product is implemented as:</p>


<pre class='hljl'>
<span class='hljl-nd'>@grad</span><span class='hljl-t'> </span><span class='hljl-nf'>dot</span><span class='hljl-p'>(</span><span class='hljl-n'>xs</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>ys</span><span class='hljl-p'>)</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>dot</span><span class='hljl-p'>(</span><span class='hljl-nf'>data</span><span class='hljl-p'>(</span><span class='hljl-n'>xs</span><span class='hljl-p'>),</span><span class='hljl-t'> </span><span class='hljl-nf'>data</span><span class='hljl-p'>(</span><span class='hljl-n'>ys</span><span class='hljl-p'>)),</span><span class='hljl-t'> </span><span class='hljl-n'>Δ</span><span class='hljl-t'> </span><span class='hljl-oB'>-&gt;</span><span class='hljl-t'> </span><span class='hljl-p'>(</span><span class='hljl-n'>Δ</span><span class='hljl-t'> </span><span class='hljl-oB'>.*</span><span class='hljl-t'> </span><span class='hljl-n'>ys</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>Δ</span><span class='hljl-t'> </span><span class='hljl-oB'>.*</span><span class='hljl-t'> </span><span class='hljl-n'>xs</span><span class='hljl-p'>)</span>
</pre>


<p>This is read as: the value going forward is computed by using the Julia <code>dot</code> function on the underlying arrays, and the pullback captures the values of the forward pass and uses <code>Δ .* ys</code> as the derivative with respect to <code>xs</code>, and <code>Δ .* xs</code> as the derivative with respect to <code>ys</code>. This simple scaling makes sense: the output is a scalar, so the Jacobian of the dot product with respect to <code>xs</code> is the row vector <span class="math">$ys^T$</span>, and pulling back the scalar <code>Δ</code> just scales <code>ys</code> by it.</p>
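<p>One way to convince yourself of this rule is to write it as a plain function and compare against a finite difference. The helper <code>dot_with_pullback</code> below is a hypothetical stand-in for the <code>@grad</code> definition that works on ordinary arrays, so it does not require Tracker.jl.</p>

<pre class='hljl'>
using LinearAlgebra: dot

# Hypothetical stand-in for the @grad rule: return the value and the pullback.
dot_with_pullback(xs, ys) = (dot(xs, ys), Δ -&gt; (Δ .* ys, Δ .* xs))

xs, ys = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
val, back = dot_with_pullback(xs, ys)
dxs, dys = back(1.0)    # seed Δ = 1 to get the plain gradients

# Central finite difference in the first component of xs as a sanity check.
h = 1e-6
e1 = [1.0, 0.0, 0.0]
fd = (dot(xs .+ h .* e1, ys) - dot(xs .- h .* e1, ys)) / (2h)
isapprox(fd, dxs[1]; atol = 1e-6)
</pre>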
<p>Note that this also allows utilizing intermediates of the forward pass within the reverse pass. This is seen in the definition of the pullback of <code>meanpool</code>:</p>


<pre class='hljl'>
<span class='hljl-nd'>@grad</span><span class='hljl-t'> </span><span class='hljl-k'>function</span><span class='hljl-t'> </span><span class='hljl-nf'>meanpool</span><span class='hljl-p'>(</span><span class='hljl-n'>x</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>pdims</span><span class='hljl-oB'>::</span><span class='hljl-n'>PoolDims</span><span class='hljl-p'>;</span><span class='hljl-t'> </span><span class='hljl-n'>kw</span><span class='hljl-oB'>...</span><span class='hljl-p'>)</span><span class='hljl-t'>
  </span><span class='hljl-n'>y</span><span class='hljl-t'> </span><span class='hljl-oB'>=</span><span class='hljl-t'> </span><span class='hljl-nf'>meanpool</span><span class='hljl-p'>(</span><span class='hljl-nf'>data</span><span class='hljl-p'>(</span><span class='hljl-n'>x</span><span class='hljl-p'>),</span><span class='hljl-t'> </span><span class='hljl-n'>pdims</span><span class='hljl-p'>;</span><span class='hljl-t'> </span><span class='hljl-n'>kw</span><span class='hljl-oB'>...</span><span class='hljl-p'>)</span><span class='hljl-t'>
  </span><span class='hljl-n'>y</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>Δ</span><span class='hljl-t'> </span><span class='hljl-oB'>-&gt;</span><span class='hljl-t'> </span><span class='hljl-p'>(</span><span class='hljl-nf'>nobacksies</span><span class='hljl-p'>(</span><span class='hljl-sc'>:meanpool</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>NNlib</span><span class='hljl-oB'>.</span><span class='hljl-nf'>∇meanpool</span><span class='hljl-p'>(</span><span class='hljl-n'>data</span><span class='hljl-oB'>.</span><span class='hljl-p'>((</span><span class='hljl-n'>Δ</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>y</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>x</span><span class='hljl-p'>))</span><span class='hljl-oB'>...</span><span class='hljl-p'>,</span><span class='hljl-t'> </span><span class='hljl-n'>pdims</span><span class='hljl-p'>;</span><span class='hljl-t'> </span><span class='hljl-n'>kw</span><span class='hljl-oB'>...</span><span class='hljl-p'>)),</span><span class='hljl-t'> </span><span class='hljl-n'>nothing</span><span class='hljl-p'>)</span><span class='hljl-t'>
</span><span class='hljl-k'>end</span>
</pre>


<p>where the derivative makes use of not only <code>x</code>, but also <code>y</code> so that the <code>meanpool</code> does not need to be re-calculated.</p>
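<p>The same idea can be sketched on a smaller function. The example below is an illustration only (the helper <code>tanh_with_pullback</code> is hypothetical): the pullback of an element-wise <code>tanh</code> closes over the forward result <code>y</code> instead of recomputing it.</p>

<pre class='hljl'>
# Hypothetical helper: value and pullback for element-wise tanh
function tanh_with_pullback(x)
    y = tanh.(x)                          # forward intermediate
    pullback = Δ -&gt; Δ .* (1 .- y .^ 2)    # d/dx tanh(x) = 1 - tanh(x)^2, reuses y
    return y, pullback
end

x = [0.0, 0.5, -1.0]
y, back = tanh_with_pullback(x)
dx = back(ones(3))
</pre>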
<p>Using this style, Tracker.jl moves forward, building up the value and closures for the backpass and then recursively pulls back the input <code>Δ</code> to receive the derivative.</p>
<h3>Source-to-Source AD</h3>
<p>Given our previous discussions on performance, you should be horrified with how this approach handles scalar values. Each <code>TrackedReal</code> holds a <code>Tracked&#123;T&#125;</code> which holds a <code>Call</code>, not a <code>Call&#123;F,As&lt;:Tuple&#125;</code>, and thus it&#39;s not strictly typed. Because it&#39;s not strictly typed, every single operation is going to cause heap allocations. If you measure this in PyTorch, TensorFlow Eager, Tracker, etc. you get around 500ns-2ms of overhead. This means that a 2ns <code>&#43;</code> operation becomes... &gt;500ns&#33; Oh my&#33;</p>
<p>This is not the only issue with tracing. Another issue is that the trace is value-dependent, meaning that every new value can build a new trace. Thus one cannot easily JIT compile a trace because it&#39;ll be different for every gradient calculation &#40;you can compile it, but you better make sure the compile times are short&#33;&#41;. Lastly, the Wengert list can be much larger than the code itself. For example, if you trace through a loop that is <code>for i in 1:100000</code>, then the trace will be huge, even if the function is relatively simple. This is directly demonstrated in the JAX &quot;how it works&quot; slide:</p>
<p><img src="https://iaml.it/blog/jax-intro-english/images/lifecycle.png" alt="" /></p>
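<p>The growth of the trace with the loop length can be seen with a toy tape. This is a sketch under stated assumptions: <code>Traced</code>, <code>TAPE</code>, and <code>poly_loop</code> are hypothetical names, and the tape only records operation symbols rather than the closures a real tracer would store.</p>

<pre class='hljl'>
# Toy tape: every traced operation appends one entry
const TAPE = Symbol[]

struct Traced
    value::Float64
end

Base.:+(a::Traced, b::Traced) = (push!(TAPE, :+); Traced(a.value + b.value))
Base.:*(a::Traced, b::Traced) = (push!(TAPE, :*); Traced(a.value * b.value))

# A three-line function whose trace grows with the iteration count n
function poly_loop(x, n)
    acc = x
    for _ in 1:n
        acc = acc * x + x
    end
    return acc
end

empty!(TAPE)
result = poly_loop(Traced(0.5), 100_000)
tape_length = length(TAPE)
</pre>

<p>Each iteration records one multiplication and one addition, so the tape holds 200,000 entries even though the source of the loop is a single line.</p>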
<p>To avoid these issues, another version of reverse-mode automatic differentiation is <em>source-to-source</em> transformations. In order to do source code transformations, you need to know how to transform all language constructs via the reverse pass. This can be quite difficult &#40;what is the &quot;adjoint&quot; of <code>lock</code>?&#41;, but when worked out this has a few benefits. First of all, you do not have to track values, meaning stack-allocated values can stay on the stack. Additionally, you can JIT compile one backpass because you have a single function used for all backpasses. Lastly, you don&#39;t need to unroll your loops&#33; Instead, with each branch you&#39;d need to insert some data structure to recall the values used from the forward pass &#40;in order to invert in the right directions&#41;. However, that can be much more lightweight than a tracking pass.</p>
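<p>As a hand-written illustration of what such a transformation might emit for a loop (this is not the output of Zygote.jl or any other tool, and the function names are hypothetical), the forward pass keeps the loop and pushes the values it will need onto a stack, and the reverse pass is the same loop run backwards:</p>

<pre class='hljl'>
# Forward pass: run the loop once, pushing the values the reverse pass will need
function iterated_sin(x, n)
    stack = Float64[]
    for _ in 1:n
        push!(stack, x)   # remember the input of this iteration
        x = sin(x)
    end
    return x, stack
end

# Reverse pass: the same loop run backwards, popping the stored inputs
function iterated_sin_pullback(stack, Δ)
    for _ in 1:length(stack)
        Δ *= cos(pop!(stack))   # chain rule: d sin(x)/dx = cos(x)
    end
    return Δ
end

y, stack = iterated_sin(1.0, 5)
dydx = iterated_sin_pullback(stack, 1.0)
</pre>

<p>The only per-iteration storage is one number on the stack, and both functions are ordinary compiled loops regardless of <code>n</code>.</p>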
<p>This can be a difficult problem for a general-purpose programming language. In general it needs a strong programmatic representation to use as a compute graph. Google&#39;s engineers did an analysis <a href="https://github.com/tensorflow/swift/blob/master/docs/WhySwiftForTensorFlow.md">when choosing Swift for TensorFlow</a> and narrowed it down to either Swift or Julia due to their internal graph structures. Thus, it should be no surprise that the modern source-to-source AD systems are Zygote.jl for Julia, and Swift for TensorFlow in Swift. Additionally, older AD systems, like Tapenade, ADIFOR, and TAF, all for Fortran, were source-to-source AD systems.</p>
<h2>Derivation of Reverse Mode Rules: Adjoints and Implicit Function Theorem</h2>
<p>In order to require the least amount of work from our AD system, we need to be able to derive the adjoint rules at the highest level possible. Here are a few well-known cases to start understanding. These next examples are from <a href="https://math.mit.edu/~stevenj/18.336/adjoint.pdf">Steven Johnson&#39;s resource</a>.</p>
<h3>Adjoint of Linear Solve</h3>
<p>Let&#39;s say we have the function <span class="math">$A(p)x=b(p)$</span>, i.e. this is the function that is given by the linear solving process, and we want to calculate the gradients of a cost function <span class="math">$g(x,p)$</span>. To evaluate the gradient directly, we&#39;d calculate:</p>
<p class="math">\[
\frac{dg}{dp} = g_p + g_x x_p
\]</p>
<p>where <span class="math">$x_p$</span> is the derivative of each value of <span class="math">$x$</span> with respect to each parameter <span class="math">$p$</span>, and thus it&#39;s an <span class="math">$M \times P$</span> matrix &#40;a Jacobian&#41;. Since <span class="math">$g$</span> is a small cost function, <span class="math">$g_p$</span> and <span class="math">$g_x$</span> are easy to compute, but <span class="math">$x_p$</span> is given by:</p>
<p class="math">\[
x_{p_i} = A^{-1}(b_{p_i}-A_{p_i}x)
\]</p>
<p>and so this is <span class="math">$P$</span> <span class="math">$M \times M$</span> linear solves, which is expensive&#33; However, if we multiply by</p>
<p class="math">\[
\lambda^{T} = g_x A^{-1}
\]</p>
<p>then, writing <span class="math">$f(x,p) = A(p)x - b(p)$</span> for the residual of the linear system, we obtain</p>
<p class="math">\[
\frac{dg}{dp}\vert_{f=0} = g_p - \lambda^T f_p = g_p - \lambda^T (A_p x - b_p)
\]</p>
<p>which is an alternative formulation of the derivative at the solution value. However, in this case there is no computational benefit to this reformulation.</p>
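<p>A minimal numerical sketch of this formula is given below. The particular <code>A(p)</code>, <code>b(p)</code>, and cost <code>g</code> are assumptions chosen for illustration, and their partial derivatives are coded by hand.</p>

<pre class='hljl'>
using LinearAlgebra

# Illustrative problem (all names are assumptions): A(p) x = b(p) with two parameters
A(p) = [(2.0 + p[1]) 1.0; 1.0 (3.0 + p[2])]
b(p) = [1.0, p[1] * p[2]]
g(x) = sum(abs2, x)                 # cost depends on x only, so g_p = 0

function adjoint_gradient(p)
    x = A(p) \ b(p)                 # forward solve
    gx = 2 .* x                     # g_x, stored as a vector
    λ = A(p)' \ gx                  # adjoint solve: A^T λ = g_x^T
    # hand-coded partial derivatives of A and b for this example
    Ap = ([1.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 1.0])
    bp = ([0.0, p[2]], [0.0, p[1]])
    return [-dot(λ, Ap[i] * x - bp[i]) for i in 1:2]
end

p = [0.5, 1.5]
grad = adjoint_gradient(p)
</pre>

<p>A central finite difference of <code>g(A(p) \ b(p))</code> in each parameter is a simple way to check the result.</p>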
<h3>Adjoint of Nonlinear Solve</h3>
<p>Now let&#39;s look at some <span class="math">$f(x,p)=0$</span> nonlinear solving. Differentiating by <span class="math">$p$</span> gives us:</p>
<p class="math">\[
f_x x_p + f_p = 0
\]</p>
<p>and thus <span class="math">$x_p = -f_x^{-1}f_p$</span>. Therefore, using our cost function we write:</p>
<p class="math">\[
\frac{dg}{dp} = g_p + g_x x_p = g_p - g_x \left(f_x^{-1} f_p \right)
\]</p>
<p>or</p>
<p class="math">\[
\frac{dg}{dp} = g_p - \left(g_x f_x^{-1} \right) f_p
\]</p>
<p>Since <span class="math">$g_x$</span> is <span class="math">$1 \times M$</span>, <span class="math">$f_x^{-1}$</span> is <span class="math">$M \times M$</span>, and <span class="math">$f_p$</span> is <span class="math">$M \times P$</span>, this regrouping avoids forming the <span class="math">$M \times P$</span> matrix <span class="math">$f_x^{-1} f_p$</span>: the product <span class="math">$g_x f_x^{-1}$</span> is only <span class="math">$1 \times M$</span>.</p>
<p>As is normal with backpasses, we solve for <span class="math">$x$</span> through the forward pass however we like, and then for the backpass solve for</p>
<p class="math">\[
f_x^T \lambda = g_x^T
\]</p>
<p>to obtain</p>
<p class="math">\[
\frac{dg}{dp}\vert_{f=0} = g_p - \lambda^T f_p
\]</p>
<p>which does the calculation without ever building the <span class="math">$M \times P$</span> matrix <span class="math">$x_p$</span>.</p>
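<p>The recipe can be sketched on a small system. Everything below is an assumption made for illustration: the residual <code>f</code>, its hand-coded Jacobian <code>fx</code>, the cost <code>g</code>, and a fixed-iteration Newton solver standing in for &quot;however we like&quot;.</p>

<pre class='hljl'>
using LinearAlgebra

# Illustrative residual and hand-coded Jacobian (assumed for this sketch)
f(x, p) = [x[1]^3 + x[1] + 0.5 * x[2] - p[1],
           x[2]^3 + x[2] + 0.5 * x[1] - p[2]]
fx(x) = [(3 * x[1]^2 + 1) 0.5; 0.5 (3 * x[2]^2 + 1)]
fp = -Matrix(1.0I, 2, 2)            # partial derivative of f with respect to p
g(x) = sum(abs2, x)                 # cost depends on x only, so g_p = 0

# Forward pass: plain Newton iteration with a fixed number of steps
function newton_solve(p; iters = 50)
    x = zeros(2)
    for _ in 1:iters
        x -= fx(x) \ f(x, p)
    end
    return x
end

# Backpass: one linear solve with the transposed Jacobian
function adjoint_gradient(p)
    x = newton_solve(p)
    λ = fx(x)' \ (2 .* x)           # f_x^T λ = g_x^T
    return -(fp' * λ)               # transpose of -λ^T f_p
end

p = [1.0, 2.0]
grad = adjoint_gradient(p)
</pre>

<p>The backpass costs a single <code>2 × 2</code> solve here, independent of how many Newton iterations the forward pass needed.</p>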
<h3>Adjoint of Ordinary Differential Equations</h3>
<p>We wish to solve for some cost function <span class="math">$G(u,p)$</span> evaluated throughout the differential equation, i.e.:</p>
<p class="math">\[
G(u,p) = G(u(p)) = \int_{t_0}^T g(u(t,p))dt
\]</p>
<p>To derive this adjoint, introduce the Lagrange multiplier <span class="math">$\lambda$</span> to form:</p>
<p class="math">\[
I(p) = G(p) - \int_{t_0}^T \lambda^\ast (u^\prime - f(u,p,t))dt
\]</p>
<p>Since <span class="math">$u^\prime = f(u,p,t)$</span>, this is the mathematician&#39;s trick of adding zero, so then we have that</p>
<p class="math">\[
\frac{dG}{dp} = \frac{dI}{dp} = \int_{t_0}^T (g_p + g_u s)dt - \int_{t_0}^T \lambda^\ast (s^\prime - f_u s - f_p)dt
\]</p>
<p>for <span class="math">$s$</span> being the sensitivity, <span class="math">$s = \frac{du}{dp}$</span>. After applying integration by parts to <span class="math">$\lambda^\ast s^\prime$</span>, we get that:</p>
<p class="math">\[
\int_{t_{0}}^{T}\lambda^{\ast}\left(s^{\prime}-f_{u}s-f_{p}\right)dt=\int_{t_{0}}^{T}\lambda^{\ast}s^{\prime}dt-\int_{t_{0}}^{T}\lambda^{\ast}\left(f_{u}s+f_{p}\right)dt
\]</p>
<p class="math">\[
=\left[\lambda^{\ast}(t)s(t)\right]_{t_{0}}^{T}-\int_{t_{0}}^{T}\lambda^{\ast\prime}sdt-\int_{t_{0}}^{T}\lambda^{\ast}\left(f_{u}s+f_{p}\right)dt
\]</p>
<p>To see where we ended up, let&#39;s re-arrange the full expression now:</p>
<p class="math">\[
\frac{dG}{dp}=\int_{t_{0}}^{T}(g_{p}+g_{u}s)dt-\left[\lambda^{\ast}(t)s(t)\right]_{t_{0}}^{T}+\int_{t_{0}}^{T}\lambda^{\ast\prime}sdt+\int_{t_{0}}^{T}\lambda^{\ast}\left(f_{u}s+f_{p}\right)dt
\]</p>
<p class="math">\[
=\int_{t_{0}}^{T}(g_{p}+\lambda^{\ast}f_{p})dt-\left[\lambda^{\ast}(t)s(t)\right]_{t_{0}}^{T}+\int_{t_{0}}^{T}\left(\lambda^{\ast\prime}+\lambda^\ast f_{u}+g_{u}\right)sdt
\]</p>
<p>That was just a re-arrangement. Now, let&#39;s require that</p>
<p class="math">\[
\lambda^\prime = -\frac{df}{du}^\ast \lambda - \left(\frac{dg}{du} \right)^\ast
\]</p>
<p class="math">\[
\lambda(T) = 0
\]</p>
<p>This means that the boundary term of the integration by parts vanishes at <span class="math">$t = T$</span>, and also that the integral term multiplying <span class="math">$s$</span> is exactly zero. Thus, if <span class="math">$\lambda$</span> satisfies that equation, then we get:</p>
<p class="math">\[
\frac{dG}{dp} = \lambda^\ast(t_0)\frac{du}{dp}(t_0) + \int_{t_0}^T \left(g_p + \lambda^\ast f_p \right)dt
\]</p>
<p>which gives us our adjoint derivative relation.</p>
<p>If <span class="math">$G$</span> is discrete, then it can be represented via the Dirac delta:</p>
<p class="math">\[
G(u,p) = \int_{t_0}^T \sum_{i=1}^N \Vert d_i - u(t_i,p)\Vert^2 \delta(t_i - t)dt
\]</p>
<p>in which case</p>
<p class="math">\[
g_u(t_i) = 2(u(t_i,p) - d_i)
\]</p>
<p>at the data points <span class="math">$(t_i,d_i)$</span>. Therefore, the derivative of a cost function with respect to the parameters of an ODE is given by solving for <span class="math">$\lambda^\ast$</span> using an ODE for <span class="math">$\lambda^T$</span> in reverse time, and then using that to calculate <span class="math">$\frac{dG}{dp}$</span>. Note that <span class="math">$\frac{dG}{dp}$</span> can be calculated simultaneously by appending a single value to the reverse ODE, since we can simply define the new ODE term as <span class="math">$g_p + \lambda^\ast f_p$</span>, which would then calculate the integral on the fly &#40;ODE integration is just... integration&#33;&#41;.</p>
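<p>Written out, one way to set up this appended system is the following, where the extra state <span class="math">$\mu$</span> is a symbol introduced here for illustration and is defined as <span class="math">$\mu(t) = \int_t^T \left(g_p + \lambda^\ast f_p\right)^\ast ds$</span>:</p>
<p class="math">\[
\frac{d}{dt}\begin{bmatrix}\lambda \\ \mu\end{bmatrix} = \begin{bmatrix} -f_u^\ast \lambda - g_u^\ast \\ -\left(g_p + \lambda^\ast f_p\right)^\ast \end{bmatrix}, \qquad \lambda(T) = 0, \quad \mu(T) = 0
\]</p>
<p>Solving this system from <span class="math">$T$</span> back to <span class="math">$t_0$</span> gives <span class="math">$\mu(t_0)^\ast = \int_{t_0}^T \left(g_p + \lambda^\ast f_p\right)dt$</span>, which is the integral term of <span class="math">$\frac{dG}{dp}$</span>.</p>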
<h3>Complexities of Implementing ODE Adjoints</h3>
<p>The image below explains the dilemma:</p>
<p><img src="https://user-images.githubusercontent.com/1814174/66882662-4d741a00-ef99-11e9-9233-4d6804fec2ec.PNG" alt="" /></p>
<p>Essentially, the whole problem is that we need to solve the ODE</p>
<p class="math">\[
\lambda^\prime = -\frac{df}{du}^\ast \lambda - \left(\frac{dg}{du} \right)^\ast
\]</p>
<p class="math">\[
\lambda(T) = 0
\]</p>
<p>in reverse, but <span class="math">$\frac{df}{du}$</span> is defined by <span class="math">$u(t)$</span> which is a value only computed in the forward pass &#40;the forward pass is embedded within the backpass&#33;&#41;. Thus we need to be able to retrieve the value of <span class="math">$u(t)$</span> to get the Jacobian on-demand. There are three ways in which this can be done:</p>
<ol>
<li><p>If you solve the ODE <span class="math">$u^\prime = f(u,p,t)$</span> backwards in time, mathematically it&#39;ll give equivalent values. Computation-wise, this means that you can append <span class="math">$u(t)$</span> to <span class="math">$\lambda(t)$</span> &#40;and to <span class="math">$\frac{dG}{dp}$</span>&#41; to calculate all terms at the same time with a single reverse pass ODE. However, numerically this can be unstable and thus is not always recommended &#40;ODEs are reversible, but ODE solver methods are not necessarily going to generate the same exact values or trajectories in reverse&#33;&#41;.</p>
</li>
<li><p>If you solve the forward ODE and receive a continuous solution <span class="math">$u(t)$</span>, you can interpolate it to retrieve the value at any time at which the reverse pass needs the <span class="math">$\frac{df}{du}$</span> Jacobian. This is fast but memory-intensive.</p>
</li>
<li><p>Every time you need a value <span class="math">$u(t)$</span> during the backpass, you re-solve the forward ODE to <span class="math">$u(t)$</span>. This is expensive&#33; Thus one can instead use <em>checkpoints</em>, i.e. save at finitely many time points during the forward pass, and use those as starting points for the <span class="math">$u(t)$</span> calculation.</p>
</li>
</ol>
<p>Alternative strategies can be investigated, such as an interpolation which stores values in a compressed form.</p>
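<p>The first strategy can be sketched end to end on a scalar test problem. Everything here is an assumption chosen for illustration: the model <span class="math">$u^\prime = pu$</span>, the cost integrand <span class="math">$g(u) = u^2$</span>, a hand-written fixed-step RK4 integrator in place of a real ODE solver, and the function names <code>rk4</code> and <code>adjoint_dGdp</code>.</p>

<pre class='hljl'>
# Generic fixed-step RK4; a negative step (t1 less than t0) integrates in reverse time
function rk4(rhs, y, t0, t1, n)
    h = (t1 - t0) / n
    t = t0
    for _ in 1:n
        k1 = rhs(y, t)
        k2 = rhs(y .+ (h / 2) .* k1, t + h / 2)
        k3 = rhs(y .+ (h / 2) .* k2, t + h / 2)
        k4 = rhs(y .+ h .* k3, t + h)
        y = y .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return y
end

# Assumed test problem: u' = p*u and g(u) = u^2, so f_u = p, f_p = u, g_u = 2u, g_p = 0
function adjoint_dGdp(p, u0, T; n = 1000)
    uT = rk4((y, t) -&gt; [p * y[1]], [u0], 0.0, T, n)[1]   # forward pass
    # reverse pass on z = (u, λ, μ), with μ accumulating the integral of λ*f_p
    reverse_rhs(z, t) = [p * z[1],                 # u' = f(u, p, t)
                         -p * z[2] - 2 * z[1],     # λ' = -f_u*λ - g_u
                         -z[2] * z[1]]             # μ' = -(g_p + λ*f_p)
    z0 = rk4(reverse_rhs, [uT, 0.0, 0.0], T, 0.0, n)
    return z0[3]                                   # dG/dp, since u0 does not depend on p
end

dGdp = adjoint_dGdp(0.5, 1.0, 1.0)
</pre>

<p>For this problem the cost has the closed form <span class="math">$G(p) = u_0^2 (e^{2pT} - 1)/(2p)$</span>, whose derivative at <span class="math">$p = 0.5$</span>, <span class="math">$u_0 = 1$</span>, <span class="math">$T = 1$</span> is exactly 2, so the result can be checked against that value. Re-solving <span class="math">$u$</span> backwards is benign here because the reversed equation is decaying; that is a property of this example, not of the strategy in general.</p>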
<h3>The vjp and Neural Ordinary Differential Equations</h3>
<p>It is here that we can note that, if <span class="math">$f$</span> is a function defined by a neural network, we arrive at the <em>neural ordinary differential equation</em>. This adjoint method is thus the backpropagation method for the neural ODE. However, the backpass</p>
<p class="math">\[
\lambda^\prime = -\frac{df}{du}^\ast \lambda - \left(\frac{dg}{du} \right)^\ast
\]</p>
<p class="math">\[
\lambda(T) = 0
\]</p>
<p>can be improved by noticing <span class="math">$\frac{df}{du}^\ast \lambda$</span> is a vjp, and thus it can be calculated using <span class="math">$\mathcal{B}_f^{u(t)}(\lambda^\ast)$</span>, i.e. reverse-mode AD on the function <span class="math">$f$</span>. If <span class="math">$f$</span> is a neural network, this means that the reverse ODE is defined through successive backpropagation passes of that neural network. The result is the derivative of the cost function with respect to the parameters defining <span class="math">$f$</span> &#40;either a model or a neural network&#41;, which can then be used to fit the data &#40;&quot;train&quot;&#41;.</p>
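<p>As a small sketch of why the vjp is cheaper than forming the Jacobian, consider a hypothetical one-layer network <code>f(u) = tanh.(W*u + b)</code> with a hand-written backward sweep. The weights and the function names <code>vjp</code> and <code>jacobian</code> are assumptions for illustration; in practice the backward sweep would come from a reverse-mode AD tool.</p>

<pre class='hljl'>
using LinearAlgebra

# Hypothetical one-layer network f(u) = tanh.(W*u + b)
W = [0.5 -1.0; 2.0 0.3]
b = [0.1, -0.2]
f(u) = tanh.(W * u .+ b)

# vjp: returns (df/du)^T * λ with one backward sweep; no Jacobian is formed
function vjp(u, λ)
    y = f(u)
    return W' * ((1 .- y .^ 2) .* λ)
end

# Reference only: the explicit Jacobian diag(1 - y^2) * W
jacobian(u) = Diagonal(1 .- f(u) .^ 2) * W

u = [0.3, -0.7]
λ = [1.0, 2.0]
fast = vjp(u, λ)
reference = jacobian(u)' * λ
</pre>

<p>Both routes give the same vector, but the vjp only needs matrix-vector products, which is what makes the reverse ODE affordable when <span class="math">$f$</span> has many inputs.</p>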
<h2>Alternative &quot;Training&quot; Strategies</h2>
<p>Those are the &quot;brute force&quot; training methods which simply use <span class="math">$u(t,p)$</span> evaluations to calculate the cost. However, it is worth noting that there are a few better strategies that one can employ in the case of dynamical models.</p>
<h3>Multiple Shooting Techniques</h3>
<p>Instead of shooting just from the beginning, one can instead shoot from multiple points in time:</p>
<p><img src="https://user-images.githubusercontent.com/1814174/66883548-561a1f80-ef9c-11e9-9ce1-0b6b55c950f9.PNG" alt="" /></p>
<p>Of course, one won&#39;t know what the &quot;initial condition in the future&quot; is, but one can instead make that a parameter. By doing so, each interval can be solved independently, and one can then add to the cost function that the end of one interval must match up with the beginning of the other. This can make the integration more robust, since shooting with incorrect parameters over long time spans can give massive gradients which makes it hard to hone in on the correct values.</p>
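<p>A minimal sketch of such a cost function is shown below. The decay model <span class="math">$u^\prime = -pu$</span> with its exact flow, the penalty <code>weight</code>, and all function and argument names are assumptions made for illustration; a real implementation would call an ODE solver on each interval.</p>

<pre class='hljl'>
# Exact flow of the assumed test model u' = -p*u over a time span dt
flow(u, p, dt) = u * exp(-p * dt)

# θ = [p, s_2, ..., s_K]: the rate plus the unknown initial state of every later segment
function multiple_shooting_cost(θ, u0, starts, ts, data; weight = 10.0)
    p = θ[1]
    s = vcat(u0, θ[2:end])                  # initial state of each segment
    cost = 0.0
    for (t, d) in zip(ts, data)
        k = searchsortedlast(starts, t)     # segment containing time t
        cost += (d - flow(s[k], p, t - starts[k]))^2
    end
    for k in 1:length(starts)-1             # continuity penalty between segments
        gap = flow(s[k], p, starts[k+1] - starts[k]) - s[k+1]
        cost += weight * gap^2
    end
    return cost
end

starts = [0.0, 1.0, 2.0]
ts = 0.25:0.25:2.75
data = 2.0 .* exp.(-0.8 .* ts)
θ_true = [0.8, 2.0 * exp(-0.8), 2.0 * exp(-1.6)]
cost_true = multiple_shooting_cost(θ_true, 2.0, starts, ts, data)
cost_off = multiple_shooting_cost([1.2, 1.0, 0.5], 2.0, starts, ts, data)
</pre>

<p>Each segment only integrates over its own short interval, so a poor guess for <code>p</code> cannot compound over the whole time span; the continuity penalty ties the segments back together as the optimizer converges.</p>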
<h3>Collocation Methods</h3>
<p>If the data is dense enough, one can fit a curve through the points, such as a spline:</p>
<p><img src="https://user-images.githubusercontent.com/1814174/66883762-fc662500-ef9c-11e9-91c7-c445e32d120f.PNG" alt="" /></p>
<p>If that&#39;s the case, one can use the fitted spline in order to estimate the derivative at each point. Since the ODE is defined as <span class="math">$u^\prime = f(u,p,t)$</span>, one can then use the cost function</p>
<p class="math">\[
C(p) = \sum_{i=1}^N \Vert\tilde{u}^{\prime}(t_i) - f(u(t_i),p,t_i)\Vert
\]</p>
<p>where <span class="math">$\tilde{u}^{\prime}(t_i)$</span> is the estimated derivative at the time point <span class="math">$t_i$</span>. Then one can fit the parameters to ensure this holds. This method can be extremely fast since the ODE doesn&#39;t ever have to be solved&#33; However, note that this is not able to compensate for error accumulation, and thus early errors are not accounted for in the later parts of the data. This means that, if the data points are too far apart, the integrated solution won&#39;t necessarily match the data even when this fit is &quot;good&quot;, a problem that shooting methods, which fit the integrated trajectory directly, do not have. Thus, this is usually done as part of a <em>two-stage method</em>, where the starting stage uses collocation to get initial parameters, which are then refined with a shooting method.</p>
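<p>The idea can be sketched on synthetic data. The assumptions are stated loudly: the model is <span class="math">$u^\prime = -pu$</span>, central differences stand in for differentiating a fitted spline, and the squared norm is used in the cost so that the minimizer has a closed form.</p>

<pre class='hljl'>
# Synthetic dense data from the assumed model u' = -p*u with p = 1.5
p_true = 1.5
ts = 0.0:0.01:2.0
us = exp.(-p_true .* ts)

# Central differences stand in for differentiating a fitted spline
interior = 2:length(ts)-1
du = [(us[i+1] - us[i-1]) / (ts[i+1] - ts[i-1]) for i in interior]
u_mid = us[interior]

# Minimising sum((du + p*u)^2) over p is linear least squares with a closed form
p_est = -sum(du .* u_mid) / sum(abs2, u_mid)
</pre>

<p>No ODE solve appears anywhere, which is where the speed comes from. The estimate is only as good as the derivative estimates, so in a two-stage method <code>p_est</code> would serve as the initial guess for a shooting fit.</p>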


        <HR/>
        <div class="footer">
          <p>
            Published from <a href="adjoints.jmd">adjoints.jmd</a>
            using <a href="http://github.com/JunoLab/Weave.jl">Weave.jl</a> v0.10.6 on 2020-11-25.
          </p>
        </div>
      </div>
    </div>
  </div>
</BODY>

</HTML>
